from __future__ import annotations

import torch

from typing import Tuple
from argparse import Namespace

import utils
from client import Client

class Attacker:
    """Craft Byzantine client messages that all carry one shared update.

    The attacker holds the Byzantine clients it controls (``byz_clients``,
    keyed by client id) and the run configuration (``args``); ``attack``
    reads ``args.skew_lambda``.
    """

    def __init__(self, byz_clients: dict[str, Client], args: Namespace) -> None:
        self.byz_clients = byz_clients
        self.args = args

    @torch.no_grad()
    def get_ref_updates(self, sampled_byz_client_idxs: set[str], server_message: dict, knowledge: dict) -> dict[str, dict[str, torch.Tensor]]:
        """Return the reference updates the attack is built from.

        If ``knowledge['benign_client_messages']`` is non-empty, the
        ``'update'`` entry of each of those messages is returned, keyed by
        the same ids. Otherwise every sampled Byzantine client runs
        ``local_update(server_message)`` and the ``'update'`` entries of the
        resulting messages are returned instead.
        """
        benign_msgs: dict[str, dict] = knowledge['benign_client_messages']
        if len(benign_msgs) == 0:
            # Nothing benign to look at: fall back to the updates the
            # Byzantine clients produce themselves.
            own_msgs = {
                client_id: self.byz_clients[client_id].local_update(server_message)
                for client_id in sampled_byz_client_idxs
            }
            return {client_id: msg['update'] for client_id, msg in own_msgs.items()}
        return {client_id: msg['update'] for client_id, msg in benign_msgs.items()}

    @torch.no_grad()
    def pack_byz_client_msgs(self, sampled_byz_client_idxs: set[str], byz_update: dict[str, torch.Tensor]) -> dict[str, dict]:
        """Build one message per sampled Byzantine client.

        Each message is whatever the client's ``pack_other_message()``
        returns, with its ``'update'`` entry set to ``byz_update``. The same
        ``byz_update`` object is placed in every message (no copy is made).
        """
        packed = {}
        for client_id in sampled_byz_client_idxs:
            msg = self.byz_clients[client_id].pack_other_message()
            msg['update'] = byz_update
            packed[client_id] = msg
        return packed

    def attack(self, sampled_byz_client_idxs: set[str], server_message: dict, knowledge: dict) -> Tuple[dict[str, dict], dict]:
        """Produce the Byzantine messages for this round plus a debug log.

        Two regimes, chosen by comparing the number of benign messages in
        ``knowledge['benign_client_messages']`` with the number of sampled
        Byzantine clients:

        * benign <= Byzantine: the update is ``-10 * mean`` when there is
          exactly one benign message, otherwise ``mean + 10 * direction``,
          where ``direction`` comes from ``init_dev``.
        * benign > Byzantine: the update is placed relative to the
          ``benign - Byzantine`` reference rows most aligned with
          ``direction``; the log then holds ``'majority_client_idxs'`` (row
          indices into the flattened reference tensor, not client ids) and
          ``'strength'``.

        Returns:
            ``(messages, verbose_log)`` where ``messages`` maps each sampled
            Byzantine client id to its packed message.
        """
        verbose_log = {}

        ref_updates = self.get_ref_updates(sampled_byz_client_idxs, server_message, knowledge)
        # presumably one flattened row per reference client - confirm
        # against utils.flatten_updates
        flat_refs, structure = utils.flatten_updates(ref_updates)
        direction = self.init_dev(flat_refs)

        num_benign = len(knowledge['benign_client_messages'])
        num_byz = len(sampled_byz_client_idxs)

        if num_benign > num_byz:
            flat_byz_update = self._skewed_update(flat_refs, direction, num_benign - num_byz, verbose_log)
        else:
            flat_byz_update = self._overpower_update(flat_refs, direction, num_benign)

        byz_update = utils.unflatten_update(flat_byz_update, structure)
        return self.pack_byz_client_msgs(sampled_byz_client_idxs, byz_update), verbose_log

    def _overpower_update(self, flat_refs: torch.Tensor, direction: torch.Tensor, num_benign: int) -> torch.Tensor:
        """Update used when benign clients do not outnumber Byzantine ones."""
        ref_mean = flat_refs.mean(dim=0)
        if num_benign == 1:
            return -10 * ref_mean
        return ref_mean + direction * 10

    def _skewed_update(self, flat_refs: torch.Tensor, direction: torch.Tensor, n_skew: int, verbose_log: dict) -> torch.Tensor:
        """Update used when benign clients outnumber Byzantine ones.

        Picks the ``n_skew`` rows of ``flat_refs`` with the largest inner
        product with ``direction`` and moves away from their mean along a
        per-coordinate deviation. Writes ``'majority_client_idxs'`` and
        ``'strength'`` into ``verbose_log``.
        """
        ref_mean = flat_refs.mean(dim=0)
        alignment = flat_refs @ direction
        skew_idxs = alignment.topk(k=n_skew, sorted=False).indices
        skewed = flat_refs[skew_idxs]
        verbose_log['majority_client_idxs'] = skew_idxs.tolist()

        skew_mean = skewed.mean(dim=0)
        # Per coordinate: the sign points from the overall mean towards the
        # selected subset's mean; the magnitude is the subset's population
        # standard deviation.
        deviation = (skew_mean - ref_mean).sign() * skewed.std(dim=0, unbiased=False)
        # Largest pairwise distance among the selected rows.
        diameter = torch.cdist(skewed, skewed).max().item()

        def excess(s: float):
            # Farthest selected row from the candidate, minus the diameter.
            candidate = skew_mean + s * deviation
            farthest = (candidate - skewed).norm(dim=-1).max().item()
            return farthest - diameter

        upper = 10.0
        # presumably locates where `excess` crosses zero on [0, upper] to a
        # tolerance of 1e-5 - confirm against utils.bisection
        root = utils.bisection(0.0, upper, 1e-5, excess)
        strength = self.args.skew_lambda * root
        flat_byz_update = skew_mean + strength * deviation
        verbose_log['strength'] = strength
        return flat_byz_update

    def init_dev(self, flat_ref_updates: torch.Tensor):
        """Return the coordinate-wise median minus the mean along dim 0."""
        median = flat_ref_updates.median(dim=0).values
        return median - flat_ref_updates.mean(dim=0)